library(rio)
library(rstan)
library(ggplot2)
library(docopt)

# Command-line interface. docopt parses the arguments against this usage
# text, so the text itself is part of the program's behaviour.
doc <- "Usage: latent_variable_model.R [-e] [-s] [-g] [-a] [-l] [-c] [-i] [-b] [-p PROCESS_ID] [--iter N] [--warmup N] [--thin N] [--chains N]

-e               exponential gamma
-s               standardize years
-g               gamma prior on slope coefficients for continuous variables
-a               control for age
-c               include posts count scale instead of dummies
-l               take logs of years (actually, asinh)
-i               no priors on betas
-b               only include those found important in Random Forest analysis (manual)
--iter N         number of iterations in total [default: 8000]
--warmup N       number of iterations in warmup [default: 2000]
--thin N         amount of thinning [default: 5]
--chains N       number of chains [default: 1]
-p id            process ID [default: TEST]"

settings <- docopt(doc)

# Model-specification switches (one per single-letter flag).
expGamma <- settings$e
standardizeYears <- settings$s
gammaPrior <- settings$g
inclAge <- settings$a
logYears <- settings$l
postsScale <- settings$c
skipPriors <- settings$i
important <- settings$b

# Sampler configuration; converted to integer before being handed to stan().
nChains <- as.integer(settings$chains)
nIterations <- as.integer(settings$iter)
nWarmup <- as.integer(settings$warmup)
nThinning <- as.integer(settings$thin)

# Label that becomes part of every output file name.
processID <- settings$p


# Read the estimation data set (a .dta file) through rio.
polex <- import("../estimation_file_with_polexp.dta")

# Graanoogst is assigned the mean entry age of the sample, computed over the
# non-missing ages and rounded to an integer.
meanEntryAge <- round(mean(polex$entryage, na.rm = TRUE), 0)
polex$entryage[polex$rule_unique == "Graanoogst"] <- meanEntryAge

# Choose the observed indicators used to measure the latent variable.
#   notDummies: indicators that enter the Stan model as X2
#   dummies:    binary indicators that enter the Stan model as X1
# -b restricts the set to the variables flagged as important; -c replaces the
# dummy indicators by the posts count scale.

# No variable set exists for -b together with -c. Without this guard the
# script would only fail further down with "object 'notDummies' not found".
if (important && postsScale) {
  stop("Options -b and -c cannot be combined: no variable set is defined for that combination.",
       call. = FALSE)
}

if (important) {
  notDummies <- c("experienceyears", "formalexperyears")
  dummies <- c("polAnyExp1")
} else if (postsScale) {
  notDummies <- c("experienceyears", "formalexperyears", "prevtimesinoffice", "postsCountScale")
  dummies <- c()
} else {
  notDummies <- c("experienceyears", "formalexperyears", "prevtimesinoffice")
  dummies <- c("polTopPostOrMinistry", "polAnyExp1", "polAnyExp2", "polAnyExp3")
}

# Keep only the selected indicators plus entry age.
data <- polex[, c(notDummies, dummies, "entryage")]

# Listwise deletion. nonMiss marks the rows of polex that survive and is used
# again later to write the estimates back into polex.
nonMiss <- complete.cases(data)
data <- data[nonMiss, ]

# Per-column summary of the estimation sample (auto-printed at top level).
apply(data, 2, summary)

# Accumulators for the pieces of the generated Stan program: statements of
# the model, parameters and transformed parameters blocks, plus the
# per-observation sampling statements (eqs) that go inside the model's loop.
modelBlock <- parametersBlock <- transformedParametersBlock <- eqs <- NULL

# Standard normal prior on gamma itself, or on theta when gamma = exp(theta).
modelBlock <- c(modelBlock, if (!expGamma) "gamma ~ normal(0, 1)" else "theta ~ normal(0,1)")

# Running index used to give every indicator its own age coefficient Bage<j>.
j <- 0

# One bernoulli_logit equation per dummy indicator.
# seq_along() matters here: with -c the dummies vector is empty, and
# 1:length(dummies) would then iterate over c(1, 0), which (with -a) declared
# two stray Bage parameters and shifted j for the loop that follows.
for (i in seq_along(dummies)) {

  v <- dummies[i]

  # Optional age term in the linear predictor.
  ageTerm <- if (inclAge) sprintf(" + Bage%d * age[i]", j) else ""

  eqs <- c(eqs, sprintf("X1[i,%d] ~ bernoulli_logit(I%s + B%s * gamma[i]%s)", i, v, v, ageTerm))

  # Intercept I<v> and loading B<v> of this indicator.
  parametersBlock <- c(parametersBlock, sprintf("real I%s", v), sprintf("real B%s", v))

  if (!skipPriors) {
    modelBlock <- c(modelBlock, sprintf("I%s ~ normal(0, 1)", v), sprintf("B%s ~ normal(0, 1)", v))
  }

  if (inclAge) {
    parametersBlock <- c(parametersBlock, sprintf("real Bage%d", j))

    if (!skipPriors) {
      modelBlock <- c(modelBlock, sprintf("Bage%d ~ normal(1, 1)", j))
    }
  }

  j <- j + 1
}

# One equation per non-dummy indicator. Untransformed indicators are modelled
# as poisson_log on integer data; with -s or -l the indicator is transformed
# and modelled as normal with its own scale parameter S<v>.
for (i in seq_along(notDummies)) {

  v <- notDummies[i]

  # -s takes precedence over -l when both are given.
  if (standardizeYears) {
    # Centre and divide by two standard deviations.
    data[, v] <- (data[, v] - mean(data[, v], na.rm = TRUE)) / (2 * sd(data[, v], na.rm = TRUE))
  } else if (logYears) {
    data[, v] <- asinh(data[, v])
  }

  if (standardizeYears || logYears) {
    # S<v> is the scale of the normal likelihood and has a gamma prior, so it
    # must be declared positive; an unconstrained declaration lets the sampler
    # propose invalid (negative) scales.
    parametersBlock <- c(parametersBlock, sprintf("real<lower=0> S%s", v))
    modelBlock <- c(modelBlock, sprintf("S%s ~ gamma(2,1)", v))

    type <- "real"
    distr <- "normal"
  } else {
    type <- "int"
    distr <- "poisson_log"
  }

  # The age term belongs to the linear predictor, so it has to come before the
  # ", S<v>" scale argument. In the opposite order the age effect ends up being
  # added to the standard deviation of the normal distribution.
  ageTerm <- if (inclAge) sprintf(" + Bage%d * age[i]", j) else ""
  scaleTerm <- if (distr == "normal") sprintf(", S%s", v) else ""

  eqs <- c(eqs, sprintf("X2[i,%d] ~ %s(I%s + B%s * gamma[i]%s%s)", i, distr, v, v, ageTerm, scaleTerm))

  # Intercept I<v> and loading B<v>; with -g the loading is constrained to be
  # non-negative so that it can receive a gamma prior.
  parametersBlock <- c(parametersBlock, sprintf("real I%s", v), if (!gammaPrior) sprintf("real B%s", v) else sprintf("real<lower=0> B%s", v))

  if (!skipPriors) {
    modelBlock <- c(modelBlock, sprintf("I%s ~ normal(0, 1)", v), if (!gammaPrior) sprintf("B%s ~ normal(0, 1)", v) else sprintf("B%s ~ gamma(.001, .001)", v))
  }

  if (inclAge) {
    parametersBlock <- c(parametersBlock, sprintf("real Bage%d", j))

    if (!skipPriors) {
      modelBlock <- c(modelBlock, sprintf("Bage%d ~ normal(1, 1)", j))
    }
  }

  # j keeps counting across both loops so every Bage<j> name is unique.
  j <- j + 1
}

# Declarations for the Stan data block. X2 is int or real depending on the
# type chosen while the non-dummy equations were generated.
dataBlock <- c(
  "int<lower=1> N",
  "int<lower=1> k1",
  "int<lower=1> k2",
  "int age[N]",
  "int X1[N,k1]",
  sprintf("%s X2[N,k2]", type)
)

# The latent variable is either sampled directly (gamma) or sampled as theta
# and exponentiated in the transformed parameters block.
if (expGamma) {
  parametersBlock <- c(parametersBlock, "real theta[N]")
  transformedParametersBlock <- c(transformedParametersBlock, "real gamma[N]", "gamma = exp(theta)")
} else {
  parametersBlock <- c(parametersBlock, "real gamma[N]")
}

# Turn a vector of Stan statements into indented, ';'-terminated lines.
# paste0() recycles an empty vector to "", so an empty block still yields a
# single line holding only ';'.
stanStatements <- function(statements, indent) {
  paste0(indent, statements, ";\n", collapse = "")
}

# Assemble the complete Stan program.
mdl <- paste0(
  "data {\n", stanStatements(dataBlock, "\t"), "}\n\n",
  "parameters {\n", stanStatements(parametersBlock, "\t"), "}\n\n",
  "transformed parameters {\n", stanStatements(transformedParametersBlock, "\t"), "}\n\n",
  "model {\n", stanStatements(modelBlock, "\t"), "\n",
  "\tfor (i in 1:N) {\n", stanStatements(eqs, "\t\t"), "\t}\n}\n"
)

# Echo the generated program to the console.
cat(mdl)

# Run identifier: timestamp, process ID, and one suffix per active switch.
# The two vectors below are aligned position by position.
switchSuffixes <- c("_exp", "_stdYr", "_gPrior", "_age",
                    "_asinhYrs", "_postsScale", "_noPriors", "_RFimp")
switchStates <- c(expGamma, standardizeYears, gammaPrior, inclAge,
                  logYears, postsScale, skipPriors, important)

id <- paste0(
  format(Sys.time(), "%m%d_%H%M%S_"),
  processID,
  paste0(switchSuffixes[switchStates], collapse = "")
)

# Archive this script and the generated Stan program under the run id.
file.copy("latent_variable_model.R", paste0(id, "_latent_variable_model.R"))
cat(mdl, file = paste0(id, "_latent_variable_model.stan"))

# With -c there are no dummy indicators and no X1 equations in the model, but
# the data block still declares k1 >= 1 and X1[N,k1]; the non-dummy columns
# are passed as filler so the data are accepted.
# NOTE(review): X1 is declared int in the Stan data block; combined with -s or
# -l the filler columns are no longer integers -- confirm this combination runs.
if (postsScale) dummies <- notDummies # has to have something to not crash

# Fit the generated model.
# Column counts come from length() and the subsets use drop = FALSE: selecting
# a single column (e.g. -b, where dummies has one element) otherwise collapses
# the data frame to a vector, making dim(...)[2] NULL and leaving k1 undefined.
fit <- stan(model_code = mdl, data = list(
  N = nrow(data),
  age = data$entryage,
  k1 = length(dummies),
  k2 = length(notDummies),
  X1 = as.matrix(data[, dummies, drop = FALSE]),
  X2 = as.matrix(data[, notDummies, drop = FALSE])
), iter = nIterations, warmup = nWarmup, thin = nThinning, chains = nChains)

# Posterior draws of gamma, pulled out of the fit once and summarised per
# column below.
gammaDraws <- extract(fit)$gamma

# Result columns start as NA so that rows dropped by the listwise deletion
# stay missing. They are created in this order to keep the column layout of
# the exported file unchanged.
polex$gamma_high95 <- NA
polex$gamma_low95 <- NA
polex$gamma <- NA

# Posterior mean and the 2.5% / 97.5% quantiles, written into the rows that
# entered the estimation sample.
polex$gamma[nonMiss] <- apply(gammaDraws, 2, mean)
polex$gamma_low95[nonMiss] <- apply(gammaDraws, 2, quantile, probs = .025)
polex$gamma_high95[nonMiss] <- apply(gammaDraws, 2, quantile, probs = .975)

# Persist the model text, augmented data, fit object and run id, and export
# the augmented data set as a .dta file as well.
save(mdl, polex, fit, id, file = sprintf("%s_fit.Rdata", id))
export(polex, file = sprintf("%s_estimation_file.dta", id), version = 10)

print(fit)


